﻿#include "ColliderDefinitions.cginc"
#include "ContactHandling.cginc"
#include "DistanceFunctions.cginc"
#include "Simplex.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "Optimization.cginc"

#pragma kernel GenerateContacts

// Per-particle data. positions, orientations and principalRadii are handed to Optimize();
// principalRadii[].x and velocities[] are additionally blended with the simplex
// barycentric weights inside the kernel.
StructuredBuffer<float4> positions;
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> principalRadii;
StructuredBuffer<float4> velocities;

// Particle indices of all simplices; the kernel reads simplexSize consecutive entries
// starting at the offset returned by GetSimplexStartAndSize().
StructuredBuffer<int> simplices;
StructuredBuffer<aabb> simplexBounds; // bounding box of each simplex, indexed by simplex index.

// Per-collider data, both indexed by collider index.
StructuredBuffer<transform> transforms;
StructuredBuffer<shape> shapes;

// edge mesh data:
// headers are indexed by shape.dataIndex; the other three buffers are indexed relative to
// the header's firstNode / firstEdge / firstVertex offsets.
StructuredBuffer<EdgeMeshHeader> edgeMeshHeaders;
StructuredBuffer<BIHNode> edgeBihNodes;
StructuredBuffer<Edge> edges;
StructuredBuffer<float2> edgeVertices;

// Candidate pairs: x = simplex index, y = collider index.
StructuredBuffer<uint2> contactPairs;
// Index of the first pair of each shape type inside contactPairs.
StructuredBuffer<int> contactOffsetsPerType;

// Output contacts; slots are reserved through the buffer's hidden counter.
RWStructuredBuffer<contact> contacts;
// Read for the number of pairs of this shape type (entry 11 + 4 * shape type);
// entries 0 and 3 are raised atomically whenever a contact is stored.
RWStructuredBuffer<uint> dispatchBuffer;

// Only element 0 is used: it is multiplied with the collider transform to obtain
// the collider-to-solver transform.
StructuredBuffer<transform> worldToSolver;

// Capacity of the contacts buffer: contacts whose reserved slot is >= maxContacts are discarded.
uint maxContacts; 
// Time span used to extrapolate the approach velocity in the contact test.
float deltaTime;

// Capacity of the per-thread BIH traversal stack used by GenerateContacts.
#define EDGE_BIH_STACK_SIZE 12

// Generates contacts between simplices and edge mesh colliders.
//
// Each thread processes one (simplex, collider) pair of the EDGE_MESH_SHAPE type:
//  1. The simplex bounds are brought into collider space.
//  2. The collider's edge BIH is traversed iteratively with a small fixed-size stack.
//  3. For every edge in a visited leaf whose expanded bounds overlap the simplex bounds,
//     Optimize() yields a surface point and a contact is appended when the separation
//     extrapolated over deltaTime is within simplex radius + contact offset + collision margin.
//
// Contacts are appended using the counter of the contacts buffer; any contact whose slot
// index is >= maxContacts is dropped (the counter is still incremented).
[numthreads(128, 1, 1)]
void GenerateContacts (uint3 id : SV_DispatchThreadID) 
{
    uint i = id.x;

    // entry #11 in the dispatch buffer is the amount of pairs for the first shape type,
    // and each shape type is 4 entries apart: skip threads past the last edge mesh pair.
    if (i >= dispatchBuffer[11 + 4 * EDGE_MESH_SHAPE]) return; 

    // fetch the pair handled by this thread: x = simplex, y = collider.
    int firstPair = contactOffsetsPerType[EDGE_MESH_SHAPE];
    int simplexIndex = contactPairs[firstPair + i].x;
    int colliderIndex = contactPairs[firstPair + i].y;
    shape s = shapes[colliderIndex];

    // a negative data index means there is no edge mesh header to read for this shape.
    if (s.dataIndex < 0) return;

    EdgeMeshHeader header = edgeMeshHeaders[s.dataIndex];
    
    EdgeMesh meshShape;
    meshShape.colliderToSolver = worldToSolver[0].Multiply(transforms[colliderIndex]);
    meshShape.s = s;

    // invert a full matrix here to accurately represent collider bounds scale.
    float4x4 solverToCollider = Inverse(TRS(meshShape.colliderToSolver.translation.xyz, meshShape.colliderToSolver.rotation, meshShape.colliderToSolver.scale.xyz));
    aabb simplexBound = simplexBounds[simplexIndex].Transformed(solverToCollider);

    // contact offset + collision margin, divided by the collider scale so it can be
    // applied to edge bounds that are built from unscaled vertices.
    float4 marginCS = float4((s.contactOffset + collisionMargin) / meshShape.colliderToSolver.scale.xyz, 0);

    int simplexSize;
    int simplexStart = GetSimplexStartAndSize(simplexIndex, simplexSize);

    // iterative depth-first traversal of the BIH, starting at the root (node 0).
    int stack[EDGE_BIH_STACK_SIZE]; 
    int stackTop = 0;

    stack[stackTop++] = 0;

    while (stackTop > 0)
    {
        // pop node index from the stack:
        int nodeIndex = stack[--stackTop];
        BIHNode node = edgeBihNodes[header.firstNode + nodeIndex];

        // leaf node:
        if (node.firstChild < 0)
        {
            // check for contact against all edges referenced by this leaf:
            for (int dataOffset = node.start; dataOffset < node.start + node.count; ++dataOffset)
            {
                // edge endpoints: 2D vertices padded with zeros and offset by the shape center.
                Edge e = edges[header.firstEdge + dataOffset];
                float4 v1 = float4(edgeVertices[header.firstVertex + e.i1],0,0) + s.center;
                float4 v2 = float4(edgeVertices[header.firstVertex + e.i2],0,0) + s.center;
                aabb edgeBounds;
                edgeBounds.FromEdge(v1, v2, marginCS);

                // broad-phase rejection: only edges whose bounds overlap the simplex bounds are tested.
                if (edgeBounds.IntersectsAabb(simplexBound, s.is2D()))
                {
                    // simplexBary is stored in the contact after Optimize(), so it is
                    // re-initialized for every edge (presumably Optimize() updates it in place,
                    // along with simplexPoint - confirm against Optimization.cginc).
                    float4 simplexBary = BarycenterForSimplexOfSize(simplexSize);
                    float4 simplexPoint;
                   
                    // the edge is cached with its endpoints multiplied by the collider scale.
                    meshShape.edge.Cache(v1 * meshShape.colliderToSolver.scale, v2 * meshShape.colliderToSolver.scale);

                    SurfacePoint surfacePoint = Optimize(meshShape, positions, orientations, principalRadii,
                                                         simplices, simplexStart, simplexSize, simplexBary, simplexPoint, surfaceCollisionIterations, surfaceCollisionTolerance);

                    // blend particle radii and velocities using the simplex barycentric weights:
                    float4 velocity = FLOAT4_ZERO;
                    float simplexRadius = 0;
                    for (int j = 0; j < simplexSize; ++j)
                    {
                        int particleIndex = simplices[simplexStart + j];
                        simplexRadius += principalRadii[particleIndex].x * simplexBary[j];
                        velocity += velocities[particleIndex] * simplexBary[j];
                    }

                    // signed distance to the surface point and velocity, both projected onto the surface normal:
                    float dAB = dot(simplexPoint - surfacePoint.pos, surfacePoint.normal);
                    float vel = dot(velocity, surfacePoint.normal); 

                    // speculative test: keep the contact if the distance extrapolated over deltaTime
                    // falls within the simplex radius plus the contact offset and collision margin.
                    if (vel * deltaTime + dAB <= simplexRadius + s.contactOffset + collisionMargin)
                    {
                        // reserve an output slot; slots past the buffer capacity are discarded.
                        uint count = contacts.IncrementCounter();
                        if (count < maxContacts)
                        {
                            contact c = (contact)0;

                            c.pointB = surfacePoint.pos;
                            c.normal = surfacePoint.normal * meshShape.s.isInverted();
                            c.pointA = simplexBary;
                            c.bodyA = simplexIndex;
                            c.bodyB = colliderIndex;

                            contacts[count] = c;
                            
                            // raise the values later stages read from the dispatch buffer
                            // (presumably thread group count in entry 0 and contact count in entry 3).
                            InterlockedMax(dispatchBuffer[0],(count + 1) / 128 + 1);
                            InterlockedMax(dispatchBuffer[3], count + 1);
                        }
                    }
                }
            }
        }
        else // check min and/or max children:
        {
            // Pushes are guarded by the stack capacity: writing past the end of a local array
            // is undefined on the GPU, so if the hierarchy is deeper than the stack allows the
            // subtree is skipped instead of corrupting memory.

            // visit min node:
            if (simplexBound.min_[node.axis] <= node.min_ && stackTop < EDGE_BIH_STACK_SIZE)
                stack[stackTop++] = node.firstChild;

            // visit max node:
            if (simplexBound.max_[node.axis] >= node.max_ && stackTop < EDGE_BIH_STACK_SIZE)
                stack[stackTop++] = node.firstChild + 1;
        }
    }
}